import os
import json

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Apply seaborn's default theme globally, so every matplotlib figure created
# below (including those drawn through plain `plt` calls) uses seaborn styling.
sns.set()

def barplot(data,
            err=None,
            avg_parents=True,
            avg_children=True,
            parents=None,
            children=None,
            color=None,
            width=None,
            ylabel=None,
            xlabel=None,
            ylim=None,
            title=None,
            suptitle=None,
            show=True,
            space_before_avg=False,
            figsize=(6.4, 4.8),
            fontsize=12,
            fname=None):
    """
    Draw a grouped bar chart: one group per parent, one bar per child.

    Parameters
    ==========
    data : array-like, shape of (n_parents, n_children)
        Bar heights. Row ``i`` is drawn as the group at x position ``i``.
    err : None or array-like, default=None
        Symmetric error-bar half-lengths, same shape as ``data``. Rows and
        columns appended for averages or spacing get zero-length error bars.
    avg_parents : bool, default=True
        Append an extra group labelled ``'avg'`` holding the mean over parents.
    avg_children : bool, default=True
        Append an extra bar in every group labelled ``'avg'`` holding the mean
        over children.
    parents : sequence of str or None, default=None
        Group (x tick) labels. Defaults to ``'1'`` .. ``'n_parents'``.
    children : sequence of str or None, default=None
        Bar (legend) labels. Defaults to ``'1'`` .. ``'n_children'``.
    color : None, str or sequence, default=None
        One color for every bar, or one color per child. Missing entries
        (e.g. for an appended ``'avg'`` bar) are left as ``None``, which lets
        matplotlib pick the next color from its cycle.
    width : float or None, default=None
        Width of a single bar. Defaults to ``0.9 / n_children`` so that a whole
        group spans 90% of the unit spacing between ticks.
    ylabel, xlabel, title, suptitle : str or None, default=None
        Optional axis labels and titles.
    ylim : tuple or None, default=None
        Passed to ``Axes.set_ylim`` when given.
    show : bool, default=True
        Call ``plt.show()`` before returning.
    space_before_avg : bool, default=False
        Insert an empty, unlabelled group just before the ``'avg'`` group.
        Only has an effect when ``avg_parents`` is True.
    figsize : (float, float), default=(6.4, 4.8)
        Figure size in inches.
    fontsize : int, default=12
        Font size for the axis labels, tick labels and legend.
    fname : path to file to save or None, default=None
        If not None, save the figure to this file.

    Returns
    =======
    matplotlib.figure.Figure
        The figure that was drawn.

    Raises
    ======
    ValueError
        If ``data`` is not 2-D, or ``err`` does not have the same shape.
    """
    data = np.asarray(data, dtype=float)
    if data.ndim != 2:
        raise ValueError(
            "data must be 2-D (n_parents, n_children), got shape %s" % (data.shape,))
    n_parents, n_children = data.shape

    if err is not None:
        err = np.asarray(err, dtype=float)
        if err.shape != data.shape:
            raise ValueError(
                "err must have the same shape as data %s, got %s" % (data.shape, err.shape))

    # Copy into fresh lists so appending 'avg' never mutates the caller's
    # sequence, and so tuples are accepted as well as lists.
    parents = [str(i) for i in range(1, n_parents + 1)] if parents is None else list(parents)
    children = [str(i) for i in range(1, n_children + 1)] if children is None else list(children)

    if avg_parents:
        # Take the mean BEFORE any spacer row is added: the spacer is all
        # zeros and would otherwise be averaged in, biasing 'avg' downwards.
        extra_rows = [data.mean(axis=0, keepdims=True)]
        extra_labels = ['avg']
        if space_before_avg:
            extra_rows.insert(0, np.zeros((1, n_children)))
            extra_labels.insert(0, '')
        data = np.vstack([data] + extra_rows)
        if err is not None:
            err = np.vstack((err, np.zeros((len(extra_rows), n_children))))
        parents.extend(extra_labels)
        n_parents = data.shape[0]

    if avg_children:
        # Computed after the parent average, so the avg/avg cell is the grand mean.
        data = np.hstack((data, data.mean(axis=1, keepdims=True)))
        if err is not None:
            err = np.hstack((err, np.zeros((n_parents, 1))))
        children.append('avg')
        n_children += 1

    if color is None or isinstance(color, str):
        color = [color] * n_children
    else:
        # Pad so a per-child color list still works after an 'avg' bar is appended.
        color = list(color) + [None] * max(0, n_children - len(color))

    if width is None:
        ratio = 0.9  # fraction of the unit tick spacing one group occupies
        width = ratio / n_children

    x = np.arange(n_parents)

    fig, ax = plt.subplots(figsize=figsize)
    for idx, (child, bar_color) in enumerate(zip(children, color)):
        # Offsets run symmetrically from -(n-1)/2 to +(n-1)/2 bar widths so
        # each group is centered on its tick (bars are center-aligned).
        offset = (idx - (n_children - 1) / 2) * width
        ax.bar(x + offset, data[:, idx], width, label=child, color=bar_color,
               yerr=None if err is None else err[:, idx])
    if suptitle is not None:
        fig.suptitle(suptitle)
    if title is not None:
        ax.set_title(title)
    if ylabel is not None:
        ax.set_ylabel(ylabel, fontsize=fontsize)
    if xlabel is not None:
        ax.set_xlabel(xlabel, fontsize=fontsize)
    if ylim is not None:
        ax.set_ylim(ylim)
    # Separate calls rather than set_xticks(ticks, labels), which needs matplotlib >= 3.5.
    ax.set_xticks(x)
    ax.set_xticklabels(parents)
    ax.tick_params(axis='both', labelsize=fontsize)
    ax.legend(children, fontsize=fontsize, bbox_to_anchor=(1, 1))
    fig.tight_layout()

    if fname is not None:
        fig.savefig(fname)

    if show:
        plt.show()

    return fig


repository_base = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
base = os.path.join(repository_base, "results")

# Sessions whose "<session>_classification_scores.json" files are read from `base`.
sessions = [
    "sub-A", "sub-B", "sub-C", "sub-D", "sub-E", "sub-F",
    "sub-G", "sub-H", "sub-I", "sub-J", "sub-K",
]

# One row per session, one column per stream key ('1', '2', '3').
data = []
for session in sessions:
    path = os.path.join(base, "%s_classification_scores.json" % session)
    # JSON text is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        json_data = json.load(f)
    data.append([json_data['%d' % m]['report']['accuracy'] for m in range(1, 4)])

data = np.array(data)
# Append the per-session mean over streams as a last column ...
data = np.hstack((data, data.mean(axis=1, keepdims=True)))
# ... then the mean over sessions (of every column, including the new one) as a last row.
data = np.vstack((data, data.mean(axis=0, keepdims=True)))
print(data)

# Bar labels: the text after the first '-' of each session name, plus one
# label for the appended average row of `data`.
children = [session.split('-')[1] for session in sessions] + ['Avg.']
print(children)

# Transposed so streams become groups and sessions become bars within a group.
parents = ['Stream 1', 'Stream 2', 'Stream 3', 'Average']
barplot(
    data=data.T,
    avg_parents=False,
    avg_children=False,
    children=children,
    parents=parents,
    ylabel="Accuracy",
    show=False,
)
# To inspect this grouped chart interactively, call plt.show() here.

print(data)
print(data[:, -1])

fontsize = 20

# Figure: stream-averaged accuracy (last column of `data`) for every session,
# plus the overall average. Bar count follows `children`, so adding or
# removing sessions doesn't silently break the hard-coded counts.
n_bars = len(children)
plt.figure(figsize=[7, 8])
# Per-session bars in purple; the final average bar highlighted in pink.
colors = ["tab:purple"] * (n_bars - 1) + ["tab:pink"]
plt.bar(x=range(n_bars), height=data[:, -1], label=children, color=colors)
plt.xlabel("Subject", fontsize=fontsize)
plt.ylabel("Accuracy", fontsize=fontsize)
plt.xticks(range(n_bars), children, fontsize=fontsize)
plt.yticks(np.arange(10) * 0.1, fontsize=fontsize)
plt.tight_layout()
plt.savefig(os.path.join(base, "classification_barplot_each_subject.png"), dpi=300)

# Figure: session-averaged accuracy (last row of `data`) for every stream,
# plus the overall average. Bar count follows `labels` rather than hard-coded 3/4.
plt.figure(figsize=[7, 8])
labels = ["Stream 1", "Stream 2", "Stream 3", "Average"]
# Per-stream bars in purple; the final average bar highlighted in pink.
colors = ["tab:purple"] * (len(labels) - 1) + ["tab:pink"]
plt.bar(x=range(len(labels)), height=data[-1, :], label=labels, color=colors)
plt.xlabel("", fontsize=fontsize)
plt.ylabel("Accuracy", fontsize=fontsize)
plt.xticks(range(len(labels)), labels, fontsize=fontsize)
plt.yticks(np.arange(9) * 0.1, fontsize=fontsize)
plt.tight_layout()
plt.savefig(os.path.join(base, "classification_barplot_each_stream.png"), dpi=300)